# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

# Owner(s): ["oncall: quantization"]
import copy
import operator
import unittest
from typing import Any, Optional, Tuple

import torch
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.backend_config import get_qnnpack_backend_config
from torch.ao.quantization.qconfig import (
    default_per_channel_symmetric_qnnpack_qat_qconfig,
    default_symmetric_qnnpack_qat_qconfig,
)
from torch.ao.quantization.quantize_fx import prepare_qat_fx
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
)
from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
    skip_if_no_torchvision,
    skipIfNoQNNPACK,
)
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import TEST_XPU, run_tests

from torchao.quantization.pt2e import (
    FusedMovingAvgObsFakeQuantize,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
    default_fake_quant,
)
from torchao.quantization.pt2e.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)
from torchao.quantization.pt2e.quantizer import (
    DerivedQuantizationSpec,
    QuantizationAnnotation,
    QuantizationSpec,
    Quantizer,
    SharedQuantizationSpec,
)
from torchao.testing.pt2e._xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torchao.utils import get_current_accelerator_device, torch_version_at_least

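# Accelerator device (e.g. cuda or xpu) used by the GPU test variants below.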
_DEVICE = get_current_accelerator_device()


class PT2EQATTestCase(QuantizationTestCase):
    """
    Base QuantizationTestCase for PT2E QAT with some helper methods.
    """

    class _BaseConvBnModel(torch.nn.Module):
        def __init__(
            self,
            conv_class: type[torch.nn.Module],
            bn_class: type[torch.nn.Module],
            has_conv_bias: bool,
            has_bn: bool,
            has_relu: bool,
            **conv_kwargs,
        ):
            super().__init__()
            conv_kwargs.setdefault("in_channels", 3)
            conv_kwargs.setdefault("out_channels", 3)
            conv_kwargs.setdefault("kernel_size", 3)
            conv_kwargs.setdefault("bias", has_conv_bias)
            self.conv = conv_class(**conv_kwargs)
            self.bn = bn_class(conv_kwargs["out_channels"]) if has_bn else None
            self.relu = torch.nn.ReLU() if has_relu else None

        def forward(self, x):
            x = self.conv(x)
            if self.bn is not None:
                x = self.bn(x)
            if self.relu is not None:
                x = self.relu(x)
            return x

    def _get_conv_bn_model(
        self,
        has_conv_bias: bool = True,
        has_bn: bool = True,
        has_relu: bool = False,
        transpose: bool = False,
        **conv_kwargs,
    ):
        """
        Return an instance of a simple test model containing the
        conv[-bn][-relu] pattern. By default, this returns a
        conv-bn model with conv bias.
        """
        return self._BaseConvBnModel(
            self.conv_transpose_class if transpose else self.conv_class,
            self.bn_class,
            has_conv_bias,
            has_bn,
            has_relu,
            **conv_kwargs,
        )

    def _verify_symmetric_xnnpack_qat_numerics(
        self,
        model: torch.nn.Module,
        example_inputs: tuple[Any, ...],
    ):
        self._verify_symmetric_xnnpack_qat_numerics_helper(
            model,
            example_inputs,
            is_per_channel=True,
        )
        self._verify_symmetric_xnnpack_qat_numerics_helper(
            model,
            example_inputs,
            is_per_channel=False,
        )

    def _verify_symmetric_xnnpack_qat_numerics_helper(
        self,
        model: torch.nn.Module,
        example_inputs: tuple[Any, ...],
        is_per_channel: bool,
        verify_convert: bool = True,
    ):
        """
        Helper method to verify that the QAT numerics for PT2E quantization match those of
        FX graph mode quantization for symmetric qnnpack.
        """
        # resetting dynamo cache
        torch._dynamo.reset()
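        # Use a fixed seed before each flow's forward pass so that the PT2E and FX
        # flows see identical randomness (e.g. from dropout in train mode).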
        MANUAL_SEED = 100

        # PT2 export

        model_pt2e = copy.deepcopy(model)
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(
            get_symmetric_quantization_config(
                is_per_channel=is_per_channel, is_qat=True
            )
        )
        model_pt2e = torch.export.export(
            model_pt2e, example_inputs, strict=True
        ).module()
        model_pt2e = prepare_qat_pt2e(model_pt2e, quantizer)
        torch.manual_seed(MANUAL_SEED)
        after_prepare_result_pt2e = model_pt2e(*example_inputs)

        model_fx = copy.deepcopy(model)
        if is_per_channel:
            default_qconfig = default_per_channel_symmetric_qnnpack_qat_qconfig
        else:
            default_qconfig = default_symmetric_qnnpack_qat_qconfig
        qconfig_mapping = QConfigMapping().set_global(default_qconfig)
        backend_config = get_qnnpack_backend_config()
        model_fx = prepare_qat_fx(
            model_fx, qconfig_mapping, example_inputs, backend_config=backend_config
        )
        torch.manual_seed(MANUAL_SEED)
        after_prepare_result_fx = model_fx(*example_inputs)

        # Verify that numerics match
        print("model pt2e:", model_pt2e)
        print("model fx:", model_fx)
        print("diff:", after_prepare_result_pt2e - after_prepare_result_fx)
        self.assertEqual(after_prepare_result_pt2e, after_prepare_result_fx)

        if verify_convert:
            from torch.ao.quantization.quantize_pt2e import (
                _convert_to_reference_decomposed_fx,
            )

            # We don't want to impose any ordering requirements between move_exported_model_to_eval and convert_pt2e
            torch.ao.quantization.move_exported_model_to_eval(model_pt2e)
            model_pt2e = convert_pt2e(model_pt2e)
            quant_result_pt2e = model_pt2e(*example_inputs)
            model_fx.eval()
            model_fx = _convert_to_reference_decomposed_fx(
                model_fx,
                backend_config=backend_config,
            )
            quant_result_fx = model_fx(*example_inputs)
            print("converted model pt2e:", model_pt2e)
            print("onverted model fx:", model_fx)
            print("diff:", quant_result_pt2e - quant_result_fx)
            self.assertEqual(quant_result_pt2e, quant_result_fx)

    def _verify_symmetric_xnnpack_qat_graph(
        self,
        m: torch.fx.GraphModule,
        example_inputs: tuple[Any, ...],
        has_relu: bool,
        has_bias: bool = True,
        is_cuda: bool = False,
        expected_conv_literal_args: Optional[tuple[Any, ...]] = None,
        # TODO: set this to true by default
        verify_convert: bool = False,
    ):
        self._verify_symmetric_xnnpack_qat_graph_helper(
            m,
            example_inputs,
            is_per_channel=True,
            has_relu=has_relu,
            has_bias=has_bias,
            is_cuda=is_cuda,
            expected_conv_literal_args=expected_conv_literal_args,
            verify_convert=verify_convert,
        )
        self._verify_symmetric_xnnpack_qat_graph_helper(
            m,
            example_inputs,
            is_per_channel=False,
            has_relu=has_relu,
            has_bias=has_bias,
            is_cuda=is_cuda,
            expected_conv_literal_args=expected_conv_literal_args,
            verify_convert=verify_convert,
        )

    def _verify_symmetric_xnnpack_qat_graph_helper(
        self,
        m: torch.fx.GraphModule,
        example_inputs: tuple[Any, ...],
        is_per_channel: bool,
        has_relu: bool,
        has_bias: bool = True,
        is_cuda: bool = False,
        expected_conv_literal_args: Optional[tuple[Any, ...]] = None,
        verify_convert: bool = False,
    ):
        """
        Verify that the graph module matches the fused QAT [conv - bn (- relu)] pattern
        with fake quantizes inserted into the correct places.
        # TODO: also verify that metadata is copied over to the new nodes.
        """
        m = copy.deepcopy(m)
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(
            get_symmetric_quantization_config(is_per_channel, is_qat=True)
        )
        m = torch.export.export(m, example_inputs, strict=True).module()
        m = prepare_qat_pt2e(m, quantizer)
        m(*example_inputs)
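
        # Expected fused QAT pattern, verified back to front below:
        #   scale_factor = bn.weight / sqrt(bn.running_var + eps)
        #   x = conv(fq(input), fq(weight * scale_factor.reshape)[, zeros_like(bias)])
        #   x = x / scale_factor.reshape [+ bias.reshape]
        #   x = bn(x) [-> relu] -> fq(output)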

        # Verify: getitem output activation fake quantize
        output_node = list(m.graph.nodes)[-1]
        output_fq_node = output_node.args[0][0]
        self.assertTrue(output_fq_node.target.startswith("activation_post_process_"))
        output_fq_mod = getattr(m, output_fq_node.target)
        self.assertEqual(type(output_fq_mod), FusedMovingAvgObsFakeQuantize)
        self.assertEqual(
            type(output_fq_mod.activation_post_process), MovingAverageMinMaxObserver
        )
        self.assertEqual(output_fq_mod.dtype, torch.int8)
        self.assertEqual(output_fq_mod.quant_min, -128)
        self.assertEqual(output_fq_mod.quant_max, 127)

        # Verify: getitem(bn, 0) or relu(getitem(bn, 0))
        if has_relu:
            relu_node = output_fq_node.args[0]
            bn_node = relu_node.args[0]
            self.assertEqual(relu_node.target, torch.ops.aten.relu.default)
        else:
            relu_node = None
            bn_node = output_fq_node.args[0]

        # The relu node takes in the output of bn.
        # See NOTE [training ir has no getitem for bn node].
        self.assertEqual(bn_node.target, torch.ops.aten.batch_norm.default)

        # Verify: conv / scale_factor.reshape [+ bias.reshape]
        if has_bias:
            add_bias_node = bn_node.args[0]
            (div_scale_factor_node, bias_reshape_node) = add_bias_node.args
            self.assertEqual(add_bias_node.target, torch.ops.aten.add.Tensor)
            self.assertEqual(bias_reshape_node.target, torch.ops.aten.reshape.default)
        else:
            div_scale_factor_node = bn_node.args[0]
        (conv_node, scale_factor_reshape_node) = div_scale_factor_node.args
        conv_op = conv_node.target
        self.assertEqual(div_scale_factor_node.target, torch.ops.aten.div.Tensor)
        self.assertTrue(_is_conv_node(conv_node))
        self.assertEqual(
            scale_factor_reshape_node.target, torch.ops.aten.reshape.default
        )

        # Verify: conv literal args
        if expected_conv_literal_args is not None:
            assert len(expected_conv_literal_args) == 6, (
                "wrong num conv args, bad test setup"
            )
            for i in range(6):
                if i + 3 < len(conv_node.args):
                    self.assertEqual(
                        conv_node.args[i + 3], expected_conv_literal_args[i]
                    )

        # Verify: conv input activation fake quantize
        conv_input_fq_node = conv_node.args[0]
        conv_input_node = conv_input_fq_node.args[0]
        self.assertTrue(
            conv_input_fq_node.target.startswith("activation_post_process_")
        )
        conv_input_fq_mod = getattr(m, conv_input_fq_node.target)
        self.assertEqual(type(conv_input_fq_mod), FusedMovingAvgObsFakeQuantize)
        self.assertEqual(
            type(conv_input_fq_mod.activation_post_process), MovingAverageMinMaxObserver
        )
        self.assertEqual(conv_input_fq_mod.dtype, torch.int8)
        self.assertEqual(conv_input_fq_mod.quant_min, -128)
        self.assertEqual(conv_input_fq_mod.quant_max, 127)
        self.assertTrue(conv_input_node.op, "placeholder")

        # Verify: conv weight fake quantize
        conv_weight_fq_node = conv_node.args[1]
        self.assertTrue(
            conv_weight_fq_node.target.startswith("activation_post_process_")
        )
        conv_weight_fq_mod = getattr(m, conv_weight_fq_node.target)
        if is_per_channel:
            expected_weight_observer_type = MovingAveragePerChannelMinMaxObserver
        else:
            expected_weight_observer_type = MovingAverageMinMaxObserver
        self.assertEqual(type(conv_weight_fq_mod), FusedMovingAvgObsFakeQuantize)
        self.assertEqual(
            type(conv_weight_fq_mod.activation_post_process),
            expected_weight_observer_type,
        )
        self.assertEqual(conv_weight_fq_mod.dtype, torch.int8)
        self.assertEqual(conv_weight_fq_mod.quant_min, -127)
        self.assertEqual(conv_weight_fq_mod.quant_max, 127)

        # Verify: conv(fq(input), fq(weight * scale_factor.reshape), zero_bias)
        zero_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None
        mul_weight_scale_factor_node = conv_weight_fq_node.args[0]
        (
            conv_weight_fq_node,
            scale_factor_reshape_node,
        ) = mul_weight_scale_factor_node.args
        if has_bias:
            self.assertEqual(zero_bias_node.target, torch.ops.aten.zeros_like.default)
        else:
            self.assertTrue(zero_bias_node is None)
        self.assertEqual(mul_weight_scale_factor_node.target, torch.ops.aten.mul.Tensor)
        self.assertEqual(
            scale_factor_reshape_node.target, torch.ops.aten.reshape.default
        )

        # Verify: scale_factor = bn_weight / sqrt(bn_running_var + eps)
        scale_factor_node = scale_factor_reshape_node.args[0]
        (bn_weight_node, sqrt_node) = scale_factor_node.args
        bn_running_var_add_node = sqrt_node.args[0]
        (bn_running_var_node, eps) = bn_running_var_add_node.args
        self.assertEqual(scale_factor_node.target, torch.ops.aten.div.Tensor)
        self.assertTrue("bn.weight" in bn_weight_node.target)
        self.assertTrue("bn.running_var" in bn_running_var_node.target)
        self.assertEqual(sqrt_node.target, torch.ops.aten.sqrt.default)
        self.assertEqual(bn_running_var_add_node.target, torch.ops.aten.add.Tensor)
        self.assertEqual(eps, 1e-5)

        # Optionally check the converted graph
        if verify_convert:
            m = convert_pt2e(m)
            m(*example_inputs)

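            # After convert, expect quantize/dequantize pairs around the conv input
            # and output; the weight is dequantized per channel or per tensor
            # depending on the config.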
            if is_per_channel:
                conv_weight_dq_op = (
                    torch.ops.quantized_decomposed.dequantize_per_channel.default
                )
                node_occurrence = {
                    ns.call_function(
                        torch.ops.quantized_decomposed.quantize_per_tensor.default
                    ): 2,
                    ns.call_function(
                        torch.ops.quantized_decomposed.dequantize_per_tensor.default
                    ): 2,
                    ns.call_function(
                        torch.ops.quantized_decomposed.dequantize_per_channel.default
                    ): 1,
                }
            else:
                conv_weight_dq_op = (
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default
                )
                node_occurrence = {
                    ns.call_function(
                        torch.ops.quantized_decomposed.quantize_per_tensor.default
                    ): 2,
                    ns.call_function(
                        torch.ops.quantized_decomposed.dequantize_per_tensor.default
                    ): 3,
                }
            node_list = [
                ns.call_function(
                    torch.ops.quantized_decomposed.quantize_per_tensor.default
                ),
                ns.call_function(
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default
                ),
                ns.call_function(conv_weight_dq_op),
                ns.call_function(conv_op),
                ns.call_function(
                    torch.ops.quantized_decomposed.quantize_per_tensor.default
                ),
                ns.call_function(
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default
                ),
            ]

            self.checkGraphModuleNodes(
                m,
                expected_node_list=node_list,
                expected_node_occurrence=node_occurrence,
            )


@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
    """
    Base TestCase to be used for all conv-bn[-relu] fusion patterns.
    """

    # TODO: how can we avoid adding every new test to dynamo/expected_test_failures?
    # Otherwise it fails with the following error:
    #   torch._dynamo.exc.InternalTorchDynamoError:
    #   'QuantizationConfig' object has no attribute '__bool__'

    def setUp(self):
        # NB: Skip the test if this is a base class. This handles the test
        # discovery logic in buck, which finds and runs all tests here, including
        # the base class, which we don't want to run.
        if self.id() and "_Base" in self.id():
            self.skipTest("Skipping test running from base class")

    def test_qat_conv_no_bias(self):
        m1 = self._get_conv_bn_model(has_conv_bias=False, has_bn=False, has_relu=True)
        m2 = self._get_conv_bn_model(has_conv_bias=False, has_bn=False, has_relu=False)
        self._verify_symmetric_xnnpack_qat_numerics(m1, self.example_inputs)
        self._verify_symmetric_xnnpack_qat_numerics(m2, self.example_inputs)

    def test_qat_conv_bn_fusion(self):
        m = self._get_conv_bn_model()
        self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=False)
        self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)

    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "GPU unavailable")
    def test_qat_conv_bn_fusion_cuda(self):
        m = self._get_conv_bn_model().to(_DEVICE)
        example_inputs = (self.example_inputs[0].to(_DEVICE),)
        self._verify_symmetric_xnnpack_qat_graph(
            m,
            example_inputs,
            has_relu=False,
            is_cuda=True,
        )
        self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

    def test_qat_conv_bn_fusion_literal_args(self):
        class M(torch.nn.Module):
            def __init__(self, conv_class, bn_class):
                super().__init__()
                self.conv = conv_class(3, 3, 3, stride=2, padding=4)
                self.bn = bn_class(3)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        assert self.dim in [1, 2]
        if self.dim == 1:
            # stride, padding, dilation, transposed, output_padding, groups
            conv_args = ((2,), (4,), (1,), False, (0,), 1)
            example_inputs = (torch.randn(1, 3, 5),)
        else:
            # stride, padding, dilation, transposed, output_padding, groups
            conv_args = ((2, 2), (4, 4), (1, 1), False, (0, 0), 1)
            example_inputs = (torch.randn(1, 3, 5, 5),)

        m = M(self.conv_class, self.bn_class)

        self._verify_symmetric_xnnpack_qat_graph(
            m,
            example_inputs,
            has_relu=False,
            expected_conv_literal_args=conv_args,
        )
        self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

    def test_qat_conv_bn_fusion_no_conv_bias(self):
        class M2(torch.nn.Module):
            """
            Mixed conv + BN with and without conv bias.
            """

            def __init__(self, conv_class, bn_class):
                super().__init__()
                self.conv1 = conv_class(3, 3, 3, bias=False)
                self.bn1 = bn_class(3)
                self.conv2 = conv_class(3, 3, 3, bias=True)
                self.bn2 = bn_class(3)

            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.conv2(x)
                x = self.bn2(x)
                return x

        m1 = self._get_conv_bn_model(has_conv_bias=False)
        m2 = M2(self.conv_class, self.bn_class)

        assert self.dim in [1, 2]
        if self.dim == 1:
            example_inputs = (torch.randn(3, 3, 5),)
        else:
            example_inputs = (torch.randn(3, 3, 5, 5),)

        self._verify_symmetric_xnnpack_qat_graph(
            m1,
            example_inputs,
            has_relu=False,
            has_bias=False,
        )
        self._verify_symmetric_xnnpack_qat_numerics(m1, example_inputs)
        self._verify_symmetric_xnnpack_qat_numerics(m2, example_inputs)

    def test_qat_conv_bn_relu_fusion(self):
        m = self._get_conv_bn_model(has_relu=True)
        self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=True)
        self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)

    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "GPU unavailable")
    def test_qat_conv_bn_relu_fusion_cuda(self):
        m = self._get_conv_bn_model(has_relu=True).to(_DEVICE)
        example_inputs = (self.example_inputs[0].to(_DEVICE),)
        self._verify_symmetric_xnnpack_qat_graph(
            m,
            example_inputs,
            has_relu=True,
            is_cuda=True,
        )
        self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

    def test_qat_conv_bn_relu_fusion_no_conv_bias(self):
        m = self._get_conv_bn_model(has_conv_bias=False, has_relu=True)
        self._verify_symmetric_xnnpack_qat_graph(
            m,
            self.example_inputs,
            has_relu=True,
            has_bias=False,
        )
        self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)

    def test_qat_inplace_add_relu(self):
        class M(torch.nn.Module):
            def __init__(self, conv_class):
                super().__init__()
                self.conv = conv_class(1, 1, 1)
                self.relu = torch.nn.ReLU(inplace=True)

            def forward(self, x):
                x0 = x
                x = self.conv(x)
                x += x0
                x = self.relu(x)
                return x

        assert self.dim in [1, 2]
        if self.dim == 1:
            example_inputs = (torch.randn(1, 1, 3),)
        else:
            example_inputs = (torch.randn(1, 1, 3, 3),)

        m = M(self.conv_class)
        self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

    def test_qat_update_shared_qspec(self):
        """
        Test the case where nodes used in SharedQuantizationSpec were replaced
        during QAT subgraph rewriting.
        """

        class M(torch.nn.Module):
            def __init__(self, conv_class, bn_class):
                super().__init__()
                self.conv = conv_class(3, 3, 3)
                self.bn = bn_class(3)
                self.hardtanh = torch.nn.Hardtanh()

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                x = self.hardtanh(x)
                return x

        m = M(self.conv_class, self.bn_class)
        self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)

    def test_qat_preserve_source_fn_stack(self):
        """
        Test whether `source_fn_stack` is preserved after QAT fusion.
        """

        class M(torch.nn.Module):
            def __init__(self, conv_class, bn_class, backbone):
                super().__init__()
                self.conv = conv_class(5, 3, 3)
                self.bn = bn_class(3)
                self.relu = torch.nn.ReLU()
                self.backbone = backbone

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                x = self.relu(x)
                x = self.backbone(x)
                return x

        assert self.dim in [1, 2]
        if self.dim == 1:
            example_inputs = (torch.randn(1, 5, 10),)
        else:
            example_inputs = (torch.randn(1, 5, 10, 10),)

        # QAT prepare + convert
        backbone = self._get_conv_bn_model(has_relu=True)
        m = M(self.conv_class, self.bn_class, backbone)
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
        m = torch.export.export(m, example_inputs, strict=True).module()
        m = prepare_qat_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)

        # Extract the conv and relu nodes (bn was folded into conv)
        first_conv, first_relu, second_conv, second_relu = None, None, None, None
        for n in m.graph.nodes:
            if n.target == torch.ops.aten.relu.default:
                if first_relu is None:
                    assert first_conv is None, "bad test setup"
                    first_relu = n
                    first_conv = n.args[0]
                else:
                    assert second_conv is None, "bad test setup"
                    second_relu = n
                    second_conv = n.args[0]

        # Extract the conv weight and bias nodes
        def get_conv_weight_and_bias(conv_node: torch.fx.Node):
            weight_dq_node = conv_node.args[1]
            qweight_node = weight_dq_node.args[0]
            bias_node = conv_node.args[2]
            assert isinstance(qweight_node, torch.fx.Node)
            assert isinstance(bias_node, torch.fx.Node)
            return (qweight_node, bias_node)

        _, first_conv_bias = get_conv_weight_and_bias(first_conv)
        _, second_conv_bias = get_conv_weight_and_bias(second_conv)

        # Assert that each set of conv, conv weight, and conv bias are in the same partition
        def get_source_fn(node: torch.fx.Node):
            # E.g. [('l__self___backbone1_conv', <class 'torch.nn.modules.conv.Conv2d'>)]
            return node.meta["source_fn_stack"][0][0]

        # We don't currently preserve this for the quantized weight since it's folded,
        # but the user can attach "quantization_tag" to the node and it will be preserved
        # self.assertEqual(get_source_fn(first_conv), get_source_fn(first_conv_qweight))
        # self.assertEqual(get_source_fn(second_conv), get_source_fn(second_conv_qweight))

        self.assertEqual(get_source_fn(first_conv), get_source_fn(first_conv_bias))
        self.assertEqual(get_source_fn(second_conv), get_source_fn(second_conv_bias))

        # Assert that different sets of convs and relus have different partitions
        self.assertNotEqual(get_source_fn(first_conv), get_source_fn(first_relu))
        self.assertNotEqual(get_source_fn(first_conv), get_source_fn(second_conv))
        self.assertNotEqual(get_source_fn(second_conv), get_source_fn(second_relu))
        self.assertNotEqual(get_source_fn(first_relu), get_source_fn(second_relu))

    def test_qat_conv_bn_bias_derived_qspec(self):
        m = self._get_conv_bn_model()
        example_inputs = self.example_inputs
        m = torch.export.export(m, example_inputs, strict=True).module()
        quantizer = ConvBnDerivedBiasQuantizer()
        m = prepare_qat_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)
        m(*example_inputs)

        # Assert that both weight and bias are quantized
        (conv_node, _, _) = _get_conv_bn_getitem_nodes(m)
        weight_dq = conv_node.args[1]
        bias_dq = conv_node.args[2]
        self.assertEqual(
            weight_dq.target,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        )
        self.assertEqual(
            bias_dq.target,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        )
        weight_getattr = weight_dq.args[0]
        bias_getattr = bias_dq.args[0]
        self.assertEqual(
            weight_getattr.op,
            "get_attr",
        )
        self.assertEqual(
            bias_getattr.op,
            "get_attr",
        )

        # Assert that bias scale = weight scale * input scale
        input_dq = conv_node.args[0]
        input_scale = input_dq.args[1]
        bias_scale = bias_dq.args[1]
        weight_scale = weight_dq.args[1]
        self.assertEqual(bias_scale, input_scale * weight_scale)

        # Assert that args for the bias' quantize and dequantize ops
        # are copied correctly after subgraph rewriting
        (bias_qmin, bias_qmax, bias_dtype) = bias_dq.args[3:]
        self.assertEqual(bias_qmin, -(2**31))
        self.assertEqual(bias_qmax, 2**31 - 1)
        self.assertEqual(bias_dtype, torch.int32)

    def test_qat_per_channel_weight_custom_dtype(self):
        m = self._get_conv_bn_model()
        example_inputs = self.example_inputs
        m = torch.export.export(m, example_inputs, strict=True).module()
        quantizer = ConvBnInt32WeightQuantizer()
        m = prepare_qat_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)
        m(*example_inputs)

        # Assert that conv weight is quantized per channel
        (conv_node, _, _) = _get_conv_bn_getitem_nodes(m)
        weight_dq = conv_node.args[1]
        self.assertEqual(
            weight_dq.target,
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
        )
        weight_getattr = weight_dq.args[0]
        self.assertEqual(
            weight_getattr.op,
            "get_attr",
        )

        # Assert that args for the weight's dequantize ops
        # are copied correctly after subgraph rewriting
        (dq_axis, dq_qmin, dq_qmax, dq_dtype) = weight_dq.args[3:]
        self.assertEqual(dq_axis, 0)
        self.assertEqual(dq_qmin, 0)
        self.assertEqual(dq_qmax, 2**31 - 1)
        self.assertEqual(dq_dtype, torch.int32)

    def _do_test_qat_conv_transpose_bn(self, has_relu: bool):
        # Use different in/out channel sizes to test if conv weight is
        # properly transposed in QAT pattern
        m = self._get_conv_bn_model(
            has_relu=has_relu,
            transpose=True,
            in_channels=3,
            out_channels=5,
            kernel_size=3,
        )
        self._verify_symmetric_xnnpack_qat_graph(
            m,
            self.example_inputs,
            has_relu=has_relu,
            verify_convert=True,
        )

    def test_qat_conv_transpose_bn(self):
        self._do_test_qat_conv_transpose_bn(has_relu=False)

    def test_qat_conv_transpose_bn_relu(self):
        self._do_test_qat_conv_transpose_bn(has_relu=True)

    def test_qat_conv_bn_per_channel_weight_bias(self):
        m = self._get_conv_bn_model()
        example_inputs = self.example_inputs
        m = torch.export.export(m, example_inputs, strict=True).module()
        quantizer = ConvBnDerivedBiasQuantizer(is_per_channel=True)
        m = prepare_qat_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)
        m(*example_inputs)

        # Expected graph:
        #      x -> q_tensor -> dq_tensor -> conv -> q_tensor -> dq_tensor -> output
        #  weight -> q_channel -> dq_channel /
        #    bias -> q_channel -> dq_channel /

        (conv_node, _, _) = _get_conv_bn_getitem_nodes(m)
        conv_op = conv_node.target
        conv_weight_dq_op = (
            torch.ops.quantized_decomposed.dequantize_per_channel.default
        )
        node_occurrence = {
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 2,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 2,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_channel.default
            ): 2,
        }
        node_list = [
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(conv_weight_dq_op),
            ns.call_function(conv_weight_dq_op),
            ns.call_function(conv_op),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
        ]
        self.checkGraphModuleNodes(
            m,
            expected_node_list=node_list,
            expected_node_occurrence=node_occurrence,
        )

    def test_fold_bn_erases_bn_node(self):
        """
        Ensure the BN node is erased from the graph after folding
        it into conv in `convert_pt2e` even in train mode.
        """
        m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
        m = torch.export.export(m, self.example_inputs, strict=True).module()
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(
            get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
        )
        m = prepare_qat_pt2e(m, quantizer)
        m = convert_pt2e(m)
        (conv_node, bn_node, _) = _get_conv_bn_getitem_nodes(m)
        self.assertTrue(conv_node is not None)
        self.assertTrue(bn_node is None)


@skipIfNoQNNPACK
@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base):
    dim = 1
    example_inputs = (torch.randn(1, 3, 5),)
    conv_class = torch.nn.Conv1d
    conv_transpose_class = torch.nn.ConvTranspose1d
    bn_class = torch.nn.BatchNorm1d


@skipIfNoQNNPACK
@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
class TestQuantizePT2EQAT_ConvBn2d(TestQuantizePT2EQAT_ConvBn_Base):
    dim = 2
    example_inputs = (torch.randn(1, 3, 5, 5),)
    conv_class = torch.nn.Conv2d
    conv_transpose_class = torch.nn.ConvTranspose2d
    bn_class = torch.nn.BatchNorm2d

    def test_qat_shared_qspec(self):
        """
        Test that nodes used in the keys of `input_qspec_map` refer to the
        new nodes after QAT fusion, not the old nodes that no longer exist.
        """
        m = DoubleConvBnModel()
        example_inputs = (torch.randn(1, 3, 5, 5),)
        m = torch.export.export_for_training(m, example_inputs, strict=True).module()
        old_nodes = set(m.graph.nodes)
        m = prepare_qat_pt2e(m, DoubleConvBnQuantizer())
        new_nodes = set(m.graph.nodes)
        old_nodes = old_nodes - new_nodes
        assert old_nodes.isdisjoint(new_nodes), "bad test setup"
        assert len(old_nodes) == 4, (
            f"bad test setup, old nodes should have 2 convs and 2 bns: {old_nodes}"
        )

        # first, gather a list of nodes to check from input and output qspecs
        nodes_to_check = set()
        for n in m.graph.nodes:
            annotations = n.meta.get("quantization_annotation")
            if annotations is None:
                continue
            nodes_to_check.update(list(annotations.input_qspec_map.keys()))
            for qspec in list(annotations.input_qspec_map.values()) + [
                annotations.output_qspec
            ]:
                if isinstance(qspec, SharedQuantizationSpec):
                    if isinstance(qspec.edge_or_node, torch.fx.Node):
                        nodes_to_check.add(qspec.edge_or_node)
                    else:
                        (src, dest) = qspec.edge_or_node
                        nodes_to_check.update([src, dest])

        # assert that none of the nodes refer to old nodes
        self.assertEqual(len(nodes_to_check), 5)
        num_batch_norm_nodes_checked = 0
        for n in nodes_to_check:
            if n.target == torch.ops.aten.batch_norm.default:
                num_batch_norm_nodes_checked += 1
            self.assertTrue(
                n not in old_nodes,
                f"found old node {n} in qspec, old nodes: {old_nodes}",
            )
            self.assertTrue(
                n in new_nodes, f"found node {n} in qspec not in new nodes: {new_nodes}"
            )
        assert num_batch_norm_nodes_checked == 2, (
            f"bad test setup, didn't check 2 bns, only checked these: {nodes_to_check}"
        )


def _is_conv_node(n: torch.fx.Node):
    return n.op == "call_function" and n.target in [
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
        torch.ops.aten.conv_transpose1d,
        torch.ops.aten.conv_transpose1d.default,
        torch.ops.aten.conv_transpose2d,
        torch.ops.aten.conv_transpose2d.input,
    ]


def _get_conv_bn_getitem_nodes(model: torch.fx.GraphModule):
    """
    Return a 3-tuple of (conv, bn, getitem) nodes from the graph.
    """
    model.graph.eliminate_dead_code()
    model.recompile()
    conv_node = None
    bn_node = None
    getitem_node = None
    for n in model.graph.nodes:
        if _is_conv_node(n):
            conv_node = n
        if n.target in (
            torch.ops.aten._native_batch_norm_legit.default,
            torch.ops.aten.batch_norm.default,
        ):
            bn_node = n
        if n.target == operator.getitem:
            getitem_node = n
    assert conv_node is not None, "bad test setup"
    return (conv_node, bn_node, getitem_node)


class DoubleConvBnModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 3, 3, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(3)
        self.conv2 = torch.nn.Conv2d(3, 3, 3, bias=False)
        self.bn2 = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = self.bn1(x1)
        x2 = self.conv2(x)
        x2 = self.bn2(x2)
        return torch.cat((x1, x2))


class DoubleConvBnQuantizer(Quantizer):
    """
    Dummy quantizer that annotates a model with two conv-bn patterns followed by
    a torch.cat of the two conv-bn outputs.
    """

    def __init__(self):
        super().__init__()
        self.act_qspec = QuantizationSpec(
            dtype=torch.uint8,
            quant_min=0,
            quant_max=255,
            qscheme=torch.per_tensor_affine,
            observer_or_fake_quant_ctr=default_fake_quant,
        )
        self.weight_qspec = QuantizationSpec(
            dtype=torch.uint8,
            quant_min=0,
            quant_max=255,
            qscheme=torch.per_tensor_affine,
            observer_or_fake_quant_ctr=default_fake_quant,
        )

    def _get_all_nodes(self, model: torch.nn.Module) -> Tuple:
        """
        Return a 5-tuple of (conv1, bn1, conv2, bn2, cat) nodes.
        """
        conv1, bn1, conv2, bn2, cat = None, None, None, None, None
        for n in model.graph.nodes:
            if _is_conv_node(n):
                if conv1 is None:
                    conv1 = n
                else:
                    conv2 = n
            if n.target == torch.ops.aten.batch_norm.default:
                if bn1 is None:
                    bn1 = n
                else:
                    bn2 = n
            if n.target == torch.ops.aten.cat.default:
                cat = n
        assert conv1 is not None and bn1 is not None, "bad test setup"
        assert conv2 is not None and bn2 is not None, "bad test setup"
        assert cat is not None, "bad test setup"
        return (conv1, bn1, conv2, bn2, cat)

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        (conv1, bn1, conv2, bn2, cat) = self._get_all_nodes(model)
        conv1.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                conv1.args[0]: self.act_qspec,
                conv1.args[1]: self.weight_qspec,
            },
            _annotated=True,
        )
        bn1.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=self.act_qspec,
            _annotated=True,
        )

        conv2.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                conv2.args[0]: self.act_qspec,
                conv2.args[1]: self.weight_qspec,
            },
            _annotated=True,
        )
        bn2.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=self.act_qspec,
            _annotated=True,
        )
        cat.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                bn1: SharedQuantizationSpec(bn1),
                bn2: SharedQuantizationSpec(bn2),
            },
            output_qspec=self.act_qspec,
            _annotated=True,
        )
        return model

    def validate(self, model: torch.fx.GraphModule):
        pass


class ConvBnInt32WeightQuantizer(Quantizer):
    """
    Dummy quantizer that annotates conv bn in such a way that the weights
    are quantized per channel to int32.
    """

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        conv_node, bn_node, getitem_node = _get_conv_bn_getitem_nodes(model)
        act_qspec = QuantizationSpec(
            dtype=torch.uint8,
            quant_min=0,
            quant_max=255,
            qscheme=torch.per_tensor_affine,
            observer_or_fake_quant_ctr=default_fake_quant,
        )
        weight_qspec = QuantizationSpec(
            dtype=torch.int32,
            quant_min=0,
            quant_max=2**31 - 1,
            qscheme=torch.per_channel_affine,
            observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args(
                observer=MovingAveragePerChannelMinMaxObserver,
            ),
        )
        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                conv_node.args[0]: act_qspec,
                conv_node.args[1]: weight_qspec,
            },
            _annotated=True,
        )

        # See NOTE [training ir has no getitem for bn node].
        bn_node.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=act_qspec,
            _annotated=True,
        )
        return model

    def validate(self, model: torch.fx.GraphModule):
        pass


class ConvBnDerivedBiasQuantizer(Quantizer):
    """
    Dummy quantizer that annotates conv bn in such a way that the bias qparams are
    derived from the conv input activation and weight qparams.
    """

    def __init__(self, is_per_channel: bool = False):
        super().__init__()
        self.is_per_channel = is_per_channel

    def _derive_bias_qparams_from_act_and_weight_qparams(self, obs_or_fqs):
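        # Derive the bias scale as act_scale * weight_scale with zero point 0,
        # the usual convention for int32 bias quantization.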
        act_scale, _ = obs_or_fqs[0].calculate_qparams()
        weight_scale, _ = obs_or_fqs[1].calculate_qparams()
        if self.is_per_channel:
            bias_scale = act_scale * weight_scale
            bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)
        else:
            bias_scale = torch.tensor([act_scale * weight_scale], dtype=torch.float32)
            bias_zero_point = torch.tensor([0], dtype=torch.int32)
        return bias_scale, bias_zero_point

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        if self.is_per_channel:
            weight_qscheme = torch.per_channel_symmetric
            weight_fq = FusedMovingAvgObsFakeQuantize.with_args(
                observer=MovingAveragePerChannelMinMaxObserver,
            )
        else:
            weight_qscheme = torch.per_tensor_affine
            weight_fq = default_fake_quant
        conv_node, bn_node, getitem_node = _get_conv_bn_getitem_nodes(model)
        act_qspec = QuantizationSpec(
            dtype=torch.uint8,
            quant_min=0,
            quant_max=255,
            qscheme=torch.per_tensor_affine,
            observer_or_fake_quant_ctr=default_fake_quant,
        )
        weight_qspec = QuantizationSpec(
            dtype=torch.uint8,
            quant_min=0,
            quant_max=255,
            qscheme=weight_qscheme,
            observer_or_fake_quant_ctr=weight_fq,
        )
        bias_qspec = DerivedQuantizationSpec(
            derived_from=[
                (conv_node.args[0], conv_node),
                (conv_node.args[1], conv_node),
            ],
            derive_qparams_fn=self._derive_bias_qparams_from_act_and_weight_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=weight_qscheme,
            ch_axis=0 if self.is_per_channel else None,
        )
        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                conv_node.args[0]: act_qspec,
                conv_node.args[1]: weight_qspec,
                conv_node.args[2]: bias_qspec,
            },
            _annotated=True,
        )

        # NOTE [training ir has no getitem for bn node].
        # getitem is None when we use the training IR: it outputs
        # aten.batch_norm.default, which does not need any getitem node.
        # In this case, we need to annotate the batch norm node instead.
        # The getitem node should only be None when we are using the training IR.

        bn_node.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=act_qspec,
            _annotated=True,
        )
        return model

    def validate(self, model: torch.fx.GraphModule):
        pass


@skipIfNoQNNPACK
@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
class TestQuantizePT2EQATModels(PT2EQATTestCase):
    @skip_if_no_torchvision
    @skipIfNoQNNPACK
    def test_qat_resnet18(self):
        import torchvision

        with override_quantized_engine("qnnpack"):
            example_inputs = (torch.randn(1, 3, 224, 224),)
            m = torchvision.models.resnet18()
            self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

    @skip_if_no_torchvision
    @skipIfNoQNNPACK
    def test_qat_mobilenet_v2(self):
        import torchvision

        with override_quantized_engine("qnnpack"):
            example_inputs = (torch.randn(1, 3, 224, 224),)
            m = torchvision.models.mobilenet_v2()
            self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)


@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
    class TwoLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear1 = torch.nn.Linear(16, 8, bias=False)
            self.linear2 = torch.nn.Linear(8, 8)

        def forward(self, x):
            return self.linear2(self.linear1(x))

    class QATPTQTestModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 3)
            self.linears = TestQuantizeMixQATAndPTQ.TwoLinear()
            self.my_linear = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
            linear_out = self.linears(permute_out)
            my_linear_out = self.my_linear(linear_out)
            # Hardtanh doesn't get quantized via the XNNPACK quantizer in this test
            # because it relies on the propagation rules.
            # Need to fix this.
            return torch.nn.functional.hardtanh(my_linear_out)

    def _prepare_qat_linears(self, model):
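        """
        Recursively export each Linear / TwoLinear child and prepare it for QAT
        in place; the rest of the model is left for PTQ in the caller.
        """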
        for name, child in model.named_children():
            if isinstance(child, (torch.nn.Linear, TestQuantizeMixQATAndPTQ.TwoLinear)):
                if isinstance(child, torch.nn.Linear):
                    in_channels = child.weight.size(1)
                else:
                    in_channels = child.linear1.weight.size(1)

                example_input = (torch.rand((1, in_channels)),)
                traced_child = torch.export.export(
                    child, example_input, strict=True
                ).module()
                quantizer = XNNPACKQuantizer()
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True, is_qat=True
                )
                quantizer.set_global(quantization_config)
                traced_child_prepared = prepare_qat_pt2e(traced_child, quantizer)
                setattr(model, name, traced_child_prepared)
            else:
                self._prepare_qat_linears(child)

    def _convert_qat_linears(self, model):
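        """
        Recursively move each previously prepared child to eval mode and convert
        it in place.
        """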
        for name, child in model.named_children():
            if isinstance(child, torch.fx.GraphModule):
                torch.ao.quantization.move_exported_model_to_eval(child)
                converted_child = convert_pt2e(child)
                setattr(model, name, converted_child)
            else:
                self._convert_qat_linears(child)

    @unittest.skip("Failing with AssertionError: Guard failed: x.size()[0] == 1")
    def test_mixing_qat_ptq(self):
        example_inputs = (torch.randn(2, 3, 4, 4),)
        model = TestQuantizeMixQATAndPTQ.QATPTQTestModule()

        self._prepare_qat_linears(model)

        model(*example_inputs)
        # TODO: this should be fixed to use model.eval()
        self._convert_qat_linears(model)
        model(*example_inputs)

        model_pt2e = torch.export.export(model, example_inputs, strict=True).module()

        quantizer = XNNPACKQuantizer()
        quantizer.set_module_type(torch.nn.Linear, None)
        quantization_config = get_symmetric_quantization_config()
        quantizer.set_global(quantization_config)
        model_pt2e = prepare_pt2e(model_pt2e, quantizer)
        after_prepare_result_pt2e = model_pt2e(*example_inputs)  # noqa: F841
        model_pt2e = convert_pt2e(model_pt2e)
        quant_result_pt2e = model_pt2e(*example_inputs)  # noqa: F841

        exported_model = torch.export.export(model_pt2e, example_inputs, strict=True)

        node_occurrence = {
            # conv2d: 1 for act, 1 for weight, 1 for output
            # 3 x linear: 1 for act, 1 for output
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 8,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 9,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_channel.default
            ): 3,
            # There needs to be one for hardtanh
        }
        self.checkGraphModuleNodes(
            exported_model.graph_module, expected_node_occurrence=node_occurrence
        )


if __name__ == "__main__":
    run_tests()
